# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
# Changes Made from original:
#   import paths
#   modified arguments
#   validating arguments
# ******************************************************************************
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""TensorFlow NMT model implementation."""
from __future__ import print_function

import argparse
import os
import random
import sys

import numpy as np
import tensorflow as tf

from nlp_architect.utils import io
from . import inference
from . import train
from .gnmt.utils import misc_utils as utils, vocab_utils, evaluation_utils

# Runs at import time; presumably verifies that the installed TensorFlow
# version is supported by the GNMT utilities -- see misc_utils to confirm.
utils.check_tensorflow_version()

# Module-level slot for the parsed command-line flags. It is only initialised
# here; NOTE(review): presumably assigned by the entry point elsewhere in this
# file -- confirm.
FLAGS = None

# Hparams that ensure_compatible_hparams() copies from the freshly built
# defaults onto hparams loaded from disk when override_loaded_hparams is not
# set, so inference-time settings can change without touching the rest.
INFERENCE_KEYS = [
    "src_max_len_infer",
    "tgt_max_len_infer",
    "subword_option",
    "infer_batch_size",
    "beam_width",
    "length_penalty_weight",
    "sampling_temperature",
    "num_translations_per_input",
    "infer_mode",
]


# pylint: disable=too-many-statements
def add_arguments(parser):
    """Register every NMT command-line option on ``parser``.

    The parser is extended in place and nothing is returned.  A custom
    ``"bool"`` argparse type (backed by ``io.validate_boolean``) is registered
    first so that boolean options can be declared with ``type="bool"``.

    Args:
        parser (argparse.ArgumentParser): the parser to add the options to.
    """
    # Lets the options below use type="bool" for textual boolean values.
    parser.register("type", "bool", io.validate_boolean)

    # network
    parser.add_argument("--num_units", type=int, default=32, help="Network size.")
    parser.add_argument("--num_layers", type=int, default=2, help="Network depth.")
    parser.add_argument(
        "--num_encoder_layers",
        type=int,
        default=None,
        help="Encoder depth, equal to num_layers if None.",
    )
    parser.add_argument(
        "--num_decoder_layers",
        type=int,
        default=None,
        help="Decoder depth, equal to num_layers if None.",
    )
    parser.add_argument(
        "--encoder_type",
        type=str,
        default="uni",
        choices=["uni", "bi", "gnmt"],
        help="""\
      uni | bi | gnmt.
      For bi, we build num_encoder_layers/2 bi-directional layers.
      For gnmt, we build 1 bi-directional layer, and (num_encoder_layers - 1)
        uni-directional layers.\
      """,
    )
    parser.add_argument(
        "--residual",
        type="bool",
        nargs="?",
        const=True,
        default=False,
        help="Whether to add residual connections.",
    )
    parser.add_argument(
        "--time_major",
        type="bool",
        nargs="?",
        const=True,
        default=True,
        help="Whether to use time-major mode for dynamic RNN.",
    )
    parser.add_argument(
        "--num_embeddings_partitions",
        type=int,
        default=0,
        help="Number of partitions for embedding vars.",
    )

    # attention mechanisms
    parser.add_argument(
        "--attention",
        type=str,
        default="",
        choices=["", "luong", "scaled_luong", "bahdanau", "normed_bahdanau"],
        help="""\
      luong | scaled_luong | bahdanau | normed_bahdanau or set to "" for no
      attention\
      """,
    )
    parser.add_argument(
        "--attention_architecture",
        type=str,
        default="standard",
        choices=["standard", "gnmt", "gnmt_v2"],
        help="""\
      standard | gnmt | gnmt_v2.
      standard: use top layer to compute attention.
      gnmt: GNMT style of computing attention, use previous bottom layer to
          compute attention.
      gnmt_v2: similar to gnmt, but use current bottom layer to compute
          attention.\
      """,
    )
    parser.add_argument(
        "--output_attention",
        type="bool",
        nargs="?",
        const=True,
        default=True,
        help="""\
      Only used in standard attention_architecture. Whether use attention as
      the cell output at each timestep.\
      """,
    )
    parser.add_argument(
        "--pass_hidden_state",
        type="bool",
        nargs="?",
        const=True,
        default=True,
        help="""\
      Whether to pass encoder's hidden state to decoder when using an attention
      based model.\
      """,
    )

    # optimizer
    parser.add_argument(
        "--optimizer", type=str, default="sgd", choices=["sgd", "adam"], help="sgd | adam"
    )
    parser.add_argument(
        "--learning_rate", type=float, default=1.0, help="Learning rate. Adam: 0.001 | 0.0001"
    )
    parser.add_argument(
        "--warmup_steps", type=int, default=0, help="How many steps we inverse-decay learning."
    )
    parser.add_argument(
        "--warmup_scheme",
        type=str,
        default="t2t",
        choices=["t2t"],
        help="""\
      How to warmup learning rates. Options include:
        t2t: Tensor2Tensor's way, start with lr 100 times smaller, then
             exponentiate until the specified lr.\
      """,
    )
    parser.add_argument(
        "--decay_scheme",
        type=str,
        default="",
        choices=["", "luong234", "luong5", "luong10"],
        help="""\
      How we decay learning rate. Options include:
        luong234: after 2/3 num train steps, we start halving the learning rate
          for 4 times before finishing.
        luong5: after 1/2 num train steps, we start halving the learning rate
          for 5 times before finishing.
        luong10: after 1/2 num train steps, we start halving the learning rate
          for 10 times before finishing.\
      """,
    )

    parser.add_argument("--num_train_steps", type=int, default=12000, help="Num steps to train.")
    parser.add_argument(
        "--colocate_gradients_with_ops",
        type="bool",
        nargs="?",
        const=True,
        default=True,
        help=("Whether try colocating gradients with " "corresponding op"),
    )

    # initializer
    parser.add_argument(
        "--init_op",
        type=str,
        default="uniform",
        choices=["uniform", "glorot_normal", "glorot_uniform"],
        help="uniform | glorot_normal | glorot_uniform",
    )
    parser.add_argument(
        "--init_weight",
        type=float,
        default=0.1,
        help=("for uniform init_op, initialize weights " "between [-this, this]."),
    )

    # data
    parser.add_argument("--src", type=str, default=None, help="Source suffix, e.g., en.")
    parser.add_argument("--tgt", type=str, default=None, help="Target suffix, e.g., de.")
    parser.add_argument(
        "--train_prefix",
        type=str,
        default=None,
        help="Train prefix, expect files with src/tgt suffixes.",
    )
    parser.add_argument(
        "--dev_prefix",
        type=str,
        default=None,
        help="Dev prefix, expect files with src/tgt suffixes.",
    )
    parser.add_argument(
        "--test_prefix",
        type=str,
        default=None,
        help="Test prefix, expect files with src/tgt suffixes.",
    )
    parser.add_argument("--out_dir", type=str, default=None, help="Store log/model files.")

    # Vocab
    parser.add_argument(
        "--vocab_prefix",
        type=str,
        default=None,
        help="""\
      Vocab prefix, expect files with src/tgt suffixes.\
      """,
    )
    parser.add_argument(
        "--embed_prefix",
        type=str,
        default=None,
        help="""\
      Pretrained embedding prefix, expect files with src/tgt suffixes.
      The embedding files should be Glove formatted txt files.\
      """,
    )
    parser.add_argument("--sos", type=str, default="<s>", help="Start-of-sentence symbol.")
    parser.add_argument("--eos", type=str, default="</s>", help="End-of-sentence symbol.")
    parser.add_argument(
        "--share_vocab",
        type="bool",
        nargs="?",
        const=True,
        default=False,
        help="""\
      Whether to use the source vocab and embeddings for both source and
      target.\
      """,
    )
    parser.add_argument(
        "--check_special_token",
        type="bool",
        default=True,
        help="""\
                      Whether check special sos, eos, unk tokens exist in the
                      vocab files.\
                      """,
    )

    # Sequence lengths
    parser.add_argument(
        "--src_max_len", type=int, default=50, help="Max length of src sequences during training."
    )
    parser.add_argument(
        "--tgt_max_len", type=int, default=50, help="Max length of tgt sequences during training."
    )
    parser.add_argument(
        "--src_max_len_infer",
        type=int,
        default=None,
        help="Max length of src sequences during inference.",
    )
    parser.add_argument(
        "--tgt_max_len_infer",
        type=int,
        default=None,
        help="""\
      Max length of tgt sequences during inference.  Also use to restrict the
      maximum decoding length.\
      """,
    )

    # Default settings works well (rarely need to change)
    parser.add_argument(
        "--unit_type",
        type=str,
        default="lstm",
        choices=["lstm", "gru", "layer_norm_lstm", "mlstm"],
        # Keep this text in sync with `choices`; "nas" is not an accepted value.
        help="lstm | gru | layer_norm_lstm | mlstm",
    )
    parser.add_argument(
        "--projection_type",
        type=str,
        default="dense",
        choices=["dense", "sparse"],
        help="dense | sparse",
    )
    parser.add_argument(
        "--embedding_type",
        type=str,
        default="dense",
        choices=["dense", "sparse"],
        help="dense | sparse",
    )
    parser.add_argument(
        "--forget_bias", type=float, default=1.0, help="Forget bias for BasicLSTMCell."
    )
    parser.add_argument("--dropout", type=float, default=0.2, help="Dropout rate (not keep_prob)")
    parser.add_argument(
        "--max_gradient_norm", type=float, default=5.0, help="Clip gradients to this norm."
    )
    parser.add_argument("--batch_size", type=int, default=128, help="Batch size.")

    parser.add_argument(
        "--steps_per_stats",
        type=int,
        default=100,
        help=(
            "How many training steps to do per stats logging. "
            "Save checkpoint every 10x steps_per_stats"
        ),
    )
    parser.add_argument(
        "--max_train", type=int, default=0, help="Limit on the size of training data (0: no limit)."
    )
    parser.add_argument(
        "--num_buckets", type=int, default=5, help="Put data into similar-length buckets."
    )
    parser.add_argument(
        "--num_sampled_softmax",
        type=int,
        default=0,
        help=("Use sampled_softmax_loss if > 0. " "Otherwise, use full softmax loss."),
    )

    # SPM
    parser.add_argument(
        "--subword_option",
        type=str,
        default="",
        choices=["", "bpe", "spm"],
        help="""\
                      Set to bpe or spm to activate subword desegmentation.\
                      """,
    )

    # Experimental encoding feature.
    parser.add_argument(
        "--use_char_encode",
        type="bool",
        default=False,
        help="""\
                      Whether to split each word or bpe into character, and then
                      generate the word-level representation from the character
                      representation.
                      """,
    )

    # Misc
    parser.add_argument("--num_gpus", type=int, default=1, help="Number of gpus in each worker.")
    parser.add_argument(
        "--log_device_placement",
        type="bool",
        nargs="?",
        const=True,
        default=False,
        help="Debug GPU allocation.",
    )
    parser.add_argument(
        "--metrics",
        type=str,
        default="bleu",
        help=("Comma-separated list of evaluations " "metrics (bleu,rouge,accuracy)"),
    )
    parser.add_argument(
        "--steps_per_external_eval",
        type=int,
        default=None,
        help="""\
      How many training steps to do per external evaluation.  Automatically set
      based on data if None.\
      """,
    )
    parser.add_argument("--scope", type=str, default=None, help="scope to put variables under")
    parser.add_argument(
        "--hparams_path",
        type=str,
        default=None,
        help=("Path to standard hparams json file that overrides " "hparams values from FLAGS."),
    )
    parser.add_argument(
        "--random_seed", type=int, default=None, help="Random seed (>0, set a specific seed)."
    )
    parser.add_argument(
        "--override_loaded_hparams",
        type="bool",
        nargs="?",
        const=True,
        default=False,
        help="Override loaded hparams with values specified",
    )
    parser.add_argument(
        "--num_keep_ckpts", type=int, default=5, help="Max number of checkpoints to keep."
    )
    parser.add_argument(
        "--avg_ckpts",
        type="bool",
        nargs="?",
        const=True,
        default=False,
        help=(
            """\
                      Average the last N checkpoints for external evaluation.
                      N can be controlled by setting --num_keep_ckpts.\
                      """
        ),
    )
    parser.add_argument(
        "--language_model",
        type="bool",
        nargs="?",
        const=True,
        default=False,
        help="True to train a language model, ignoring encoder",
    )

    # Inference
    parser.add_argument(
        "--ckpt", type=str, default="", help="Checkpoint file to load a model for inference."
    )
    parser.add_argument(
        "--quantize_ckpt",
        type="bool",
        default=False,
        help="Set to True to produce a quantized checkpoint from existing " "checkpoint",
    )
    parser.add_argument(
        "--from_quantized_ckpt",
        type="bool",
        default=False,
        help="Set to True when the given checkpoint is quantized",
    )
    parser.add_argument(
        "--inference_input_file", type=str, default=None, help="Set to the text to decode."
    )
    parser.add_argument(
        "--inference_list",
        type=str,
        default=None,
        help=("A comma-separated list of sentence indices " "(0-based) to decode."),
    )
    parser.add_argument(
        "--infer_batch_size", type=int, default=32, help="Batch size for inference mode."
    )
    parser.add_argument(
        "--inference_output_file",
        type=str,
        default=None,
        help="Output file to store decoding results.",
    )
    parser.add_argument(
        "--inference_ref_file",
        type=str,
        default=None,
        help=(
            """\
      Reference file to compute evaluation scores (if provided).\
      """
        ),
    )

    # Advanced inference arguments
    parser.add_argument(
        "--infer_mode",
        type=str,
        default="greedy",
        choices=["greedy", "sample", "beam_search"],
        help="Which type of decoder to use during inference.",
    )
    parser.add_argument(
        "--beam_width",
        type=int,
        default=0,
        help=(
            """\
      beam width when using beam search decoder. If 0 (default), use standard
      decoder with greedy helper.\
      """
        ),
    )
    parser.add_argument(
        "--length_penalty_weight", type=float, default=0.0, help="Length penalty for beam search."
    )
    parser.add_argument(
        "--sampling_temperature",
        type=float,
        default=0.0,
        help=(
            """\
      Softmax sampling temperature for inference decoding, 0.0 means greedy
      decoding. This option is ignored when using beam search.\
      """
        ),
    )
    parser.add_argument(
        "--num_translations_per_input",
        type=int,
        default=1,
        help=(
            """\
      Number of translations generated for each sentence. This is only used for
      inference.\
      """
        ),
    )

    # Job info
    parser.add_argument("--jobid", type=int, default=0, help="Task id of the worker.")
    parser.add_argument(
        "--num_workers", type=int, default=1, help="Number of workers (inference only)."
    )
    parser.add_argument(
        "--num_inter_threads", type=int, default=0, help="number of inter_op_parallelism_threads"
    )
    parser.add_argument(
        "--num_intra_threads", type=int, default=0, help="number of intra_op_parallelism_threads"
    )
    parser.add_argument(
        "--pruning_hparams", type=str, default=None, help="model pruning parameters"
    )


def create_hparams(flags):
    """Build the training ``HParams`` object from parsed command-line flags.

    Most hyper-parameters are taken unchanged from the attribute of the same
    name on ``flags``.  The exceptions, handled inline below, are:

    * ``num_encoder_layers`` / ``num_decoder_layers`` fall back to
      ``flags.num_layers`` when unset (or otherwise falsy);
    * ``sos`` / ``eos`` fall back to the ``vocab_utils`` defaults when empty;
    * ``metrics`` is split on commas into a list;
    * ``epoch_step`` is not a flag and always starts at 0.

    Args:
        flags: namespace-like object carrying every option read below.

    Returns:
        tf.contrib.training.HParams: hparams populated in the order listed.
    """
    settings = {}

    def copy_flags(*names):
        """Copy the named attributes of ``flags`` into ``settings`` as-is."""
        for name in names:
            settings[name] = getattr(flags, name)

    # Data
    copy_flags(
        "src",
        "tgt",
        "train_prefix",
        "dev_prefix",
        "test_prefix",
        "vocab_prefix",
        "embed_prefix",
        "out_dir",
    )

    # Networks
    copy_flags("num_units")
    settings["num_encoder_layers"] = (
        flags.num_encoder_layers if flags.num_encoder_layers else flags.num_layers
    )
    settings["num_decoder_layers"] = (
        flags.num_decoder_layers if flags.num_decoder_layers else flags.num_layers
    )
    copy_flags(
        "dropout",
        "unit_type",
        "projection_type",
        "embedding_type",
        "encoder_type",
        "residual",
        "time_major",
        "num_embeddings_partitions",
    )

    # Attention mechanisms
    copy_flags("attention", "attention_architecture", "output_attention", "pass_hidden_state")

    # Train
    copy_flags(
        "optimizer",
        "num_train_steps",
        "batch_size",
        "init_op",
        "init_weight",
        "max_gradient_norm",
        "learning_rate",
        "warmup_steps",
        "warmup_scheme",
        "decay_scheme",
        "colocate_gradients_with_ops",
        "num_sampled_softmax",
    )

    # Data constraints
    copy_flags("num_buckets", "max_train", "src_max_len", "tgt_max_len")

    # Inference
    copy_flags(
        "quantize_ckpt",
        "from_quantized_ckpt",
        "src_max_len_infer",
        "tgt_max_len_infer",
        "infer_batch_size",
    )

    # Advanced inference arguments
    copy_flags(
        "infer_mode",
        "beam_width",
        "length_penalty_weight",
        "sampling_temperature",
        "num_translations_per_input",
    )

    # Vocab
    settings["sos"] = flags.sos or vocab_utils.SOS
    settings["eos"] = flags.eos or vocab_utils.EOS
    copy_flags("subword_option", "check_special_token", "use_char_encode")

    # Misc
    copy_flags("forget_bias", "num_gpus")
    settings["epoch_step"] = 0  # record where we were within an epoch.
    copy_flags("steps_per_stats", "steps_per_external_eval", "share_vocab")
    settings["metrics"] = flags.metrics.split(",")
    copy_flags(
        "log_device_placement",
        "random_seed",
        "override_loaded_hparams",
        "num_keep_ckpts",
        "avg_ckpts",
        "language_model",
        "num_intra_threads",
        "num_inter_threads",
    )

    # Model pruning
    copy_flags("pruning_hparams")

    return tf.contrib.training.HParams(**settings)


def fix_path(file_path):
    """Return the absolute, symlink-resolved form of ``file_path``.

    ``os.path.realpath`` always produces an absolute path; relative inputs are
    resolved against the current working directory.  Prefixing the result with
    this module's directory via ``os.path.join`` therefore had no effect
    (``join`` drops everything before an absolute component) and only
    suggested, wrongly, that paths were resolved relative to the package.

    Args:
        file_path: path-like value; it is converted with ``str()`` first, so
            callers are expected to filter out ``None`` beforehand.

    Returns:
        str: the canonical absolute path.
    """
    return os.path.realpath(str(file_path))


def add_suffix(p, s):
    """Stitch a prefix and an optional suffix into ``"<prefix>.<suffix>"``.

    Args:
        p: prefix; converted with ``str()``.
        s: suffix; when ``None`` the prefix is returned on its own, otherwise
            it is converted with ``str()`` and appended after a dot.

    Returns:
        str: the combined name.
    """
    prefix = str(p)
    if s is None:
        return prefix
    return prefix + "." + str(s)


def validate_existing_filepath(prefix, suffix=None):
    """Check that the file named ``prefix[.suffix]`` exists, if a prefix is given.

    Nothing is checked when ``prefix`` is ``None`` or empty, which is how an
    option that was not supplied shows up here.

    Args:
        prefix: file path or path prefix.
        suffix: optional extension joined to ``prefix`` with a dot.
    """
    if not prefix:
        return
    candidate = add_suffix(prefix, suffix)
    io.validate_existing_filepath(fix_path(candidate))


def validate_parent_exists(file_path):
    """Check that the parent directory of ``file_path`` exists, if a path is given.

    Nothing is checked when ``file_path`` is ``None`` or empty.

    Args:
        file_path: path whose parent directory must already exist.
    """
    if not file_path:
        return
    resolved = fix_path(file_path)
    io.validate_parent_exists(resolved)


def validate_arguments(args):
    """Validate the parsed command-line arguments.

    Scalar options are checked one by one through ``io.validate`` and the
    path-like options through the file/directory validators of this module.
    Checks run in a fixed order, so the first offending option is the one
    reported.

    Args:
        args: parsed-argument namespace produced by the NMT argument parser.
    """
    none_type = type(None)
    optional_int = (int, none_type)
    optional_str = (str, none_type)

    # Each tuple is handed to io.validate unchanged.  The layout appears to be
    # (value, accepted type(s)[, lower bound, upper bound]) -- see
    # nlp_architect.utils.io for the exact contract.
    checks = [
        (args.num_units, int, 1, None),
        (args.num_layers, int, 1, None),
        (args.num_encoder_layers, optional_int, 0, None),
        (args.num_decoder_layers, optional_int, 0, None),
        (args.num_embeddings_partitions, int, 0, None),
        (args.learning_rate, float, 0.0, None),
        (args.num_train_steps, int, 1, None),
        (args.warmup_steps, int, 0, args.num_train_steps),
        (args.init_weight, float),
        (args.src, optional_str, 1, 256),
        (args.tgt, optional_str, 1, 256),
        (args.sos, str, 1, 256),
        (args.eos, str, 1, 256),
        (args.src_max_len, int, 1, None),
        (args.tgt_max_len, int, 1, None),
        (args.src_max_len_infer, optional_int, 1, None),
        (args.tgt_max_len_infer, optional_int, 1, None),
        (args.forget_bias, float, 0.0, None),
        (args.dropout, float, 0.0, 1.0),
        (args.max_gradient_norm, float, 0.000000001, None),
        (args.batch_size, int, 1, None),
        (args.steps_per_stats, int, 1, None),
        (args.max_train, int, 0, None),
        (args.num_buckets, int, 1, None),
        (args.num_sampled_softmax, int, 0, None),
        (args.num_gpus, int, 0, None),
        (args.metrics, str, 1, 256),
        (args.inference_list, optional_str, 0, 256),
        (args.steps_per_external_eval, optional_int, 1, None),
        (args.scope, optional_str, 1, 256),
        (args.random_seed, optional_int),
        (args.num_keep_ckpts, int, 0, None),
        (args.infer_batch_size, int, 1, None),
        (args.beam_width, int, 0, None),
        (args.length_penalty_weight, float, 0.0, None),
        (args.sampling_temperature, float, 0.0, None),
        (args.num_translations_per_input, int, 1, None),
        (args.jobid, int, 0, None),
        (args.num_workers, int, 1, None),
        (args.num_inter_threads, int, 0, None),
        (args.num_intra_threads, int, 0, None),
        (args.pruning_hparams, optional_str, 1, 256),
    ]
    for check in checks:
        io.validate(check)

    # A language model has no separate target side, so only the source suffix
    # is checked.  NOTE(review): extend_hparams() later sets src to tgt in this
    # mode -- confirm that checking the src suffix here is what is intended.
    suffixes = [args.src] if args.language_model else [args.src, args.tgt]
    data_prefixes = (
        args.train_prefix,
        args.dev_prefix,
        args.test_prefix,
        args.vocab_prefix,
        args.embed_prefix,
    )
    for suffix in suffixes:
        for prefix in data_prefixes:
            validate_existing_filepath(prefix, suffix)

    # Stand-alone input files must exist when given.
    for input_file in (args.inference_ref_file, args.inference_input_file, args.hparams_path):
        validate_existing_filepath(input_file)

    # Output locations only need an existing parent directory.
    for output_path in (args.ckpt, args.inference_output_file, args.out_dir):
        validate_parent_exists(output_path)


def _add_argument(hparams, key, value, update=True):
    """Set ``key`` on ``hparams``, creating the hparam when it is missing.

    Args:
        hparams: object exposing ``add_hparam(key, value)``.
        key: name of the hparam.
        value: value to store.
        update: when the hparam already exists, overwrite it only if this is
            true; a missing hparam is always added.
    """
    if not hasattr(hparams, key):
        hparams.add_hparam(key, value)
        return
    if update:
        setattr(hparams, key, value)


# pylint: disable=too-many-statements
def extend_hparams(hparams):
    """Sanity-check ``hparams`` and add the derived hyper-parameters to it.

    The object is modified in place and also returned.  In order, this:

    * rejects inconsistent layer / subword / decoder settings;
    * turns off ``pass_hidden_state`` when encoder and decoder depths differ;
    * computes the number of residual layers for encoder and decoder;
    * applies the language-model overrides (no attention, shared vocab,
      ``src`` set to ``tgt``);
    * resolves the vocabulary files through ``vocab_utils.check_vocab`` and
      records their sizes and paths;
    * records embedding partition counts and any pretrained embedding files
      that exist;
    * creates one ``best_<metric>`` directory per metric under ``out_dir``
      (plus ``avg_best_<metric>`` when ``avg_ckpts`` is set).

    Args:
        hparams: hparams object as built by ``create_hparams`` or loaded from
            a previous run.

    Returns:
        The same ``hparams`` object, extended.

    Raises:
        ValueError: if the settings are inconsistent or ``vocab_prefix`` is
            not provided.
    """
    # Sanity checks
    if hparams.encoder_type == "bi" and hparams.num_encoder_layers % 2 != 0:
        raise ValueError(
            "For bi, num_encoder_layers %d should be even" % hparams.num_encoder_layers
        )
    if hparams.attention_architecture in ["gnmt"] and hparams.num_encoder_layers < 2:
        raise ValueError(
            "For gnmt attention architecture, "
            "num_encoder_layers %d should be >= 2" % hparams.num_encoder_layers
        )
    if hparams.subword_option and hparams.subword_option not in ["spm", "bpe"]:
        raise ValueError("subword option must be either spm, or bpe")
    if hparams.infer_mode == "beam_search" and hparams.beam_width <= 0:
        raise ValueError(
            "beam_width must be greater than 0 when using beam_search decoder."
        )
    if hparams.infer_mode == "sample" and hparams.sampling_temperature <= 0.0:
        raise ValueError(
            "sampling_temperature must be greater than 0.0 when using sample decoder."
        )

    # Different number of encoder / decoder layers
    assert hparams.num_encoder_layers and hparams.num_decoder_layers
    if hparams.num_encoder_layers != hparams.num_decoder_layers:
        hparams.pass_hidden_state = False
        utils.print_out(
            "Num encoder layer %d is different from num decoder layer"
            " %d, so set pass_hidden_state to False"
            % (hparams.num_encoder_layers, hparams.num_decoder_layers)
        )

    # Set residual layers
    num_encoder_residual_layers = 0
    num_decoder_residual_layers = 0
    if hparams.residual:
        if hparams.num_encoder_layers > 1:
            num_encoder_residual_layers = hparams.num_encoder_layers - 1
        if hparams.num_decoder_layers > 1:
            num_decoder_residual_layers = hparams.num_decoder_layers - 1

        if hparams.encoder_type == "gnmt":
            # The first unidirectional layer (after the bi-directional layer) in
            # the GNMT encoder can't have residual connection due to the input is
            # the concatenation of fw_cell and bw_cell's outputs.
            num_encoder_residual_layers = hparams.num_encoder_layers - 2

            # Compatible for GNMT models
            if hparams.num_encoder_layers == hparams.num_decoder_layers:
                num_decoder_residual_layers = num_encoder_residual_layers
    _add_argument(hparams, "num_encoder_residual_layers", num_encoder_residual_layers)
    _add_argument(hparams, "num_decoder_residual_layers", num_decoder_residual_layers)

    # Language modeling
    if getattr(hparams, "language_model", None):
        hparams.attention = ""
        hparams.attention_architecture = ""
        hparams.pass_hidden_state = False
        hparams.share_vocab = True
        hparams.src = hparams.tgt
        utils.print_out(
            "For language modeling, we turn off attention and "
            "pass_hidden_state; turn on share_vocab; set src to tgt."
        )

    # Vocab
    # Get vocab file names first
    if hparams.vocab_prefix:
        src_vocab_file = hparams.vocab_prefix + "." + hparams.src
        tgt_vocab_file = hparams.vocab_prefix + "." + hparams.tgt
    else:
        raise ValueError("hparams.vocab_prefix must be provided.")

    # Source vocab
    # hparams loaded from an older run may lack check_special_token.
    check_special_token = getattr(hparams, "check_special_token", True)
    src_vocab_size, src_vocab_file = vocab_utils.check_vocab(
        src_vocab_file,
        hparams.out_dir,
        check_special_token=check_special_token,
        sos=hparams.sos,
        eos=hparams.eos,
        unk=vocab_utils.UNK,
    )

    # Target vocab
    if hparams.share_vocab:
        utils.print_out("  using source vocab for target")
        tgt_vocab_file = src_vocab_file
        tgt_vocab_size = src_vocab_size
    else:
        tgt_vocab_size, tgt_vocab_file = vocab_utils.check_vocab(
            tgt_vocab_file,
            hparams.out_dir,
            check_special_token=check_special_token,
            sos=hparams.sos,
            eos=hparams.eos,
            unk=vocab_utils.UNK,
        )
    _add_argument(hparams, "src_vocab_size", src_vocab_size)
    _add_argument(hparams, "tgt_vocab_size", tgt_vocab_size)
    _add_argument(hparams, "src_vocab_file", src_vocab_file)
    _add_argument(hparams, "tgt_vocab_file", tgt_vocab_file)

    # Num embedding partitions
    num_embeddings_partitions = getattr(hparams, "num_embeddings_partitions", 0)
    _add_argument(hparams, "num_enc_emb_partitions", num_embeddings_partitions)
    _add_argument(hparams, "num_dec_emb_partitions", num_embeddings_partitions)

    # Pretrained Embeddings
    # Start from "no embedding file"; filled in below only for files that exist.
    _add_argument(hparams, "src_embed_file", "")
    _add_argument(hparams, "tgt_embed_file", "")
    if getattr(hparams, "embed_prefix", None):
        src_embed_file = hparams.embed_prefix + "." + hparams.src
        tgt_embed_file = hparams.embed_prefix + "." + hparams.tgt

        if tf.gfile.Exists(src_embed_file):
            utils.print_out("  src_embed_file %s exist" % src_embed_file)
            hparams.src_embed_file = src_embed_file

            utils.print_out("For pretrained embeddings, set num_enc_emb_partitions to 1")
            hparams.num_enc_emb_partitions = 1
        else:
            utils.print_out("  src_embed_file %s doesn't exist" % src_embed_file)

        if tf.gfile.Exists(tgt_embed_file):
            utils.print_out("  tgt_embed_file %s exist" % tgt_embed_file)
            hparams.tgt_embed_file = tgt_embed_file

            utils.print_out("For pretrained embeddings, set num_dec_emb_partitions to 1")
            hparams.num_dec_emb_partitions = 1
        else:
            utils.print_out("  tgt_embed_file %s doesn't exist" % tgt_embed_file)

    # Evaluation
    for metric in hparams.metrics:
        best_metric_dir = os.path.join(hparams.out_dir, "best_" + metric)
        tf.gfile.MakeDirs(best_metric_dir)
        # update=False keeps a best score already stored in loaded hparams.
        _add_argument(hparams, "best_" + metric, 0, update=False)
        _add_argument(hparams, "best_" + metric + "_dir", best_metric_dir)

        if getattr(hparams, "avg_ckpts", None):
            best_metric_dir = os.path.join(hparams.out_dir, "avg_best_" + metric)
            tf.gfile.MakeDirs(best_metric_dir)
            _add_argument(hparams, "avg_best_" + metric, 0, update=False)
            _add_argument(hparams, "avg_best_" + metric + "_dir", best_metric_dir)

    return hparams


def ensure_compatible_hparams(hparams, default_hparams, hparams_path=""):
    """Reconcile previously saved hparams with the current defaults.

    Args:
      hparams: hparams object loaded from disk; it is modified in place.
      default_hparams: hparams built from the current flags. They are first
        passed through `utils.maybe_parse_standard_hparams` together with
        `hparams_path`.
      hparams_path: forwarded unchanged to
        `utils.maybe_parse_standard_hparams`.

    Returns:
      The same `hparams` object after the updates below were applied.
    """
    default_hparams = utils.maybe_parse_standard_hparams(default_hparams, hparams_path)

    # Old checkpoints only carry `num_layers`; derive the per-side layer
    # counts from it when they are absent.
    if hasattr(hparams, "num_layers"):
        for layer_key in ("num_encoder_layers", "num_decoder_layers"):
            if not hasattr(hparams, layer_key):
                hparams.add_hparam(layer_key, hparams.num_layers)

    # For compatibility, any field that exists only in the defaults is
    # added to the loaded hparams with its default value.
    defaults = default_hparams.values()
    loaded = hparams.values()
    for name, default_value in defaults.items():
        if name not in loaded:
            hparams.add_hparam(name, default_value)

    # With override_loaded_hparams set, every default key may overwrite the
    # loaded value; otherwise only the inference-related keys do.
    override_all = getattr(default_hparams, "override_loaded_hparams", None)
    keys_to_sync = defaults.keys() if override_all else INFERENCE_KEYS

    for name in keys_to_sync:
        loaded_value = getattr(hparams, name)
        default_value = defaults[name]
        if loaded_value != default_value:
            utils.print_out(
                "# Updating hparams.%s: %s -> %s"
                % (name, str(loaded_value), str(default_value))
            )
            setattr(hparams, name, default_value)
    return hparams


def create_or_load_hparams(out_dir, default_hparams, hparams_path, save_hparams=True):
    """Build the effective hparams for a run, reusing saved ones if present.

    Args:
      out_dir: directory passed to `utils.load_hparams`, and to
        `utils.save_hparams` when saving is enabled.
      default_hparams: hparams used when nothing could be loaded, and as the
        reference for `ensure_compatible_hparams` otherwise.
      hparams_path: forwarded to the standard-hparams parsing step.
      save_hparams: when true, write the final hparams to `out_dir` and to
        every `best_<metric>_dir`.

    Returns:
      The hparams after `extend_hparams` has been applied.
    """
    loaded = utils.load_hparams(out_dir)
    if loaded:
        # Saved hparams exist: bring them in line with the current defaults.
        hparams = ensure_compatible_hparams(loaded, default_hparams, hparams_path)
    else:
        # Nothing usable was loaded: start from the defaults.
        hparams = utils.maybe_parse_standard_hparams(default_hparams, hparams_path)
    hparams = extend_hparams(hparams)

    # Save HParams
    if save_hparams:
        utils.save_hparams(out_dir, hparams)
        for metric in hparams.metrics:
            best_dir = getattr(hparams, "best_" + metric + "_dir")
            utils.save_hparams(best_dir, hparams)

    # Print HParams
    utils.print_hparams(hparams)
    return hparams


def run_main(flags, default_hparams, train_fn, inference_fn, target_session=""):
    """Run training or inference, depending on `flags`.

    Inference is selected when `flags.inference_input_file` is set;
    otherwise `train_fn` is called.

    Args:
      flags: parsed command-line namespace. This function reads `jobid`,
        `num_workers`, `random_seed`, `out_dir`, `ckpt`, `hparams_path`,
        `inference_input_file`, `inference_output_file`, `inference_list`
        and `inference_ref_file` from it.
      default_hparams: hparams passed on to `create_or_load_hparams`.
      train_fn: callable invoked as
        `train_fn(hparams, target_session=target_session)`.
      inference_fn: callable invoked as
        `inference_fn(ckpt, input_file, trans_file, hparams, num_workers,
        jobid)`.
      target_session: forwarded unchanged to `train_fn`.

    Raises:
      AssertionError: if hparams have to be loaded from `flags.out_dir` but
        it is empty, or if inference is requested without
        `flags.inference_output_file`.
    """
    # Job
    jobid = flags.jobid
    num_workers = flags.num_workers
    utils.print_out("# Job id %d" % jobid)

    # GPU device
    # The session exists only to enumerate devices; the context manager
    # closes it right away so its resources are released deterministically
    # rather than whenever the object happens to be collected.
    with tf.Session() as device_probe:
        devices = device_probe.list_devices()
    utils.print_out("# Devices visible to TensorFlow: %s" % repr(devices))

    # Random
    random_seed = flags.random_seed
    if random_seed is not None and random_seed > 0:
        utils.print_out("# Set random seed to %d" % random_seed)
        # Offset by jobid so each job gets a different seed.
        random.seed(random_seed + jobid)
        np.random.seed(random_seed + jobid)

    # Model output directory
    out_dir = flags.out_dir
    if out_dir and not tf.gfile.Exists(out_dir):
        utils.print_out("# Creating output directory %s ..." % out_dir)
        tf.gfile.MakeDirs(out_dir)

    # Load hparams.
    loaded_hparams = False
    if flags.ckpt:  # Try to load hparams from the same directory as ckpt
        ckpt_dir = os.path.dirname(flags.ckpt)
        ckpt_hparams_file = os.path.join(ckpt_dir, "hparams")
        if tf.gfile.Exists(ckpt_hparams_file) or flags.hparams_path:
            # save_hparams=False: never write into the checkpoint directory.
            hparams = create_or_load_hparams(
                ckpt_dir, default_hparams, flags.hparams_path, save_hparams=False
            )
            loaded_hparams = True
    if not loaded_hparams:  # Try to load from out_dir
        assert out_dir, "out_dir is required when hparams are not loaded from ckpt"
        # Only job 0 writes the hparams files.
        hparams = create_or_load_hparams(
            out_dir, default_hparams, flags.hparams_path, save_hparams=(jobid == 0)
        )

    # Train / Decode
    if flags.inference_input_file:
        # Inference output directory
        trans_file = flags.inference_output_file
        assert trans_file, "inference_output_file is required for inference"
        trans_dir = os.path.dirname(trans_file)
        # A bare file name has an empty dirname: there is nothing to create,
        # so skip the filesystem calls (same guard as for out_dir above).
        if trans_dir and not tf.gfile.Exists(trans_dir):
            tf.gfile.MakeDirs(trans_dir)

        # Inference indices
        hparams.inference_indices = None
        if flags.inference_list:
            hparams.inference_indices = [int(token) for token in flags.inference_list.split(",")]

        # Inference
        ckpt = flags.ckpt
        if not ckpt:
            # No explicit checkpoint: fall back to the latest one in out_dir.
            ckpt = tf.train.latest_checkpoint(out_dir)
        inference_fn(ckpt, flags.inference_input_file, trans_file, hparams, num_workers, jobid)

        # Evaluation
        ref_file = flags.inference_ref_file
        if ref_file and tf.gfile.Exists(trans_file):
            for metric in hparams.metrics:
                score = evaluation_utils.evaluate(
                    ref_file, trans_file, metric, hparams.subword_option
                )
                utils.print_out("  %s: %.1f" % (metric, score))
    else:
        # Train
        train_fn(hparams, target_session=target_session)


def main(_):
    """Build default hparams from the module-level FLAGS and start the run.

    The single positional argument is ignored.
    """
    run_main(
        FLAGS,
        default_hparams=create_hparams(FLAGS),
        train_fn=train.train,
        inference_fn=inference.inference,
    )


# Script entry point: parse and validate the command line, then start main().
if __name__ == "__main__":
    nmt_parser = argparse.ArgumentParser()
    add_arguments(nmt_parser)
    # parse_known_args() returns the parsed namespace plus the list of
    # arguments this parser did not recognise, instead of erroring on them.
    FLAGS, unparsed = nmt_parser.parse_known_args()
    validate_arguments(FLAGS)
    # Pass the program name and the unrecognised arguments on to tf.app.run.
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
